import asyncio
from typing import (
    Literal,
    cast,
)

import pytest
from aiortc import (
    AudioStreamTrack,
    RTCPeerConnection,
    RTCSessionDescription,
    VideoStreamTrack,
)
from fastapi import FastAPI
from fastapi.testclient import TestClient
from fastrtc.stream import Body
from fastrtc.tracks import HandlerType
from fastrtc.webrtc_connection_mixin import WebRTCConnectionMixin


class MinimalTestStream(WebRTCConnectionMixin):
    def __init__(
        self,
        handler: HandlerType = lambda x: x,
        *,
        mode: Literal["send-receive", "receive", "send"] = "send-receive",
        modality: Literal["video", "audio", "audio-video"] = "video",
        concurrency_limit: int | None | Literal["default"] = "default",
        time_limit: float | None = None,
        allow_extra_tracks: bool = False,
    ):
        WebRTCConnectionMixin.__init__(self)
        self.mode = mode
        self.modality = modality
        self.event_handler = handler
        self.concurrency_limit = cast(
            (int),
            1 if concurrency_limit in ["default", None] else concurrency_limit,
        )
        self.time_limit = time_limit
        self.allow_extra_tracks = allow_extra_tracks
        self.server_rtc_configuration = None

    def mount(self, app: FastAPI, path: str = ""):
        from fastapi import APIRouter

        router = APIRouter(prefix=path)
        router.post("/webrtc/offer")(self.offer)
        app.include_router(router)

    async def offer(self, body: Body):
        return await self.handle_offer(
            body.model_dump(), set_outputs=self.set_additional_outputs(body.webrtc_id)
        )


@pytest.fixture()
def test_client_and_stream(request):
    app = FastAPI()
    params = request.param if hasattr(request, "param") else {}
    stream = MinimalTestStream(**params)
    stream.mount(app)
    test_client = TestClient(app)
    yield test_client, stream


class TestWebRTCConnectionMixin:
    @staticmethod
    async def setup_peer_connection(audio=False, video=False):
        pc = RTCPeerConnection()
        channel = pc.createDataChannel("test-data-channel")
        if audio:
            audio_track = AudioStreamTrack()
            pc.addTrack(audio_track)
        if video:
            video_track = VideoStreamTrack()
            pc.addTrack(video_track)

        await pc.setLocalDescription(await pc.createOffer())
        return pc, channel

    @staticmethod
    async def send_offer(
        pc,
        client,
        audio=False,
        video=False,
        webrtc_id="test_id",
        response_code=200,
        return_status_and_metadata=False,
    ):
        body = {
            "sdp": pc.localDescription.sdp,
            "type": pc.localDescription.type,
        }
        if webrtc_id is not None:
            body["webrtc_id"] = webrtc_id
        response = client.post(
            "/webrtc/offer",
            headers={"Content-Type": "application/json"},
            json=body,
        )
        assert response.status_code == response_code
        if not response_code == 200:
            return
        out = response.json()
        if return_status_and_metadata:
            assert "status" in out and "meta" in out
            return out["status"], out["meta"]
        assert "type" in out and out["type"] == "answer"
        assert "webrtc-datachannel" in out["sdp"]
        if audio:
            assert "m=audio" in out["sdp"]
        if video:
            assert "m=video" in out["sdp"]

        await pc.setRemoteDescription(RTCSessionDescription(out["sdp"], out["type"]))

        # Allow data to stream
        await asyncio.sleep(0.5)

    @staticmethod
    async def close_peer_connection(pc):
        await pc.close()
        assert pc.connectionState == "closed"
        assert pc.iceConnectionState == "closed"
        assert pc.signalingState == "closed"

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"modality": "audio"}], indirect=True
    )
    @pytest.mark.parametrize("audio", [True, False])
    async def test_successful_connection_audio(self, test_client_and_stream, audio):
        test_client, stream = test_client_and_stream
        pc, channel = await self.setup_peer_connection(audio)
        await self.send_offer(pc, test_client, audio)
        # TODO: Test stream? E.g., when no audio is not part of the offer...
        await self.close_peer_connection(pc)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"modality": "video"}], indirect=True
    )
    @pytest.mark.parametrize("video", [True, False])
    async def test_successful_connection_video(self, test_client_and_stream, video):
        test_client, stream = test_client_and_stream
        pc, channel = await self.setup_peer_connection(video=video)
        await self.send_offer(pc, test_client, video=video)
        # TODO: Test stream? E.g., when no video is not part of the offer...
        await self.close_peer_connection(pc)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"modality": "audio"}], indirect=True
    )
    @pytest.mark.parametrize("audio", [True, False])
    async def test_unsuccessful_connection_audio(self, test_client_and_stream, audio):
        test_client, stream = test_client_and_stream
        pc, channel = await self.setup_peer_connection(audio=audio, video=True)
        with pytest.raises(ValueError, match=r"Unsupported track kind .*"):
            await self.send_offer(pc, test_client, audio=audio, video=True)
        await self.close_peer_connection(pc)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"modality": "video"}], indirect=True
    )
    @pytest.mark.parametrize("video", [True, False])
    async def test_unsuccessful_connection_video(self, test_client_and_stream, video):
        test_client, stream = test_client_and_stream
        pc, channel = await self.setup_peer_connection(audio=True, video=video)
        with pytest.raises(ValueError, match=r"Unsupported track kind .*"):
            await self.send_offer(pc, test_client, audio=True, video=video)
        await self.close_peer_connection(pc)

    @pytest.mark.asyncio
    async def test_unsuccessful_webrtc_offer_no_webrtc_id(self, test_client_and_stream):
        test_client, stream = test_client_and_stream
        pc, channel = await self.setup_peer_connection()
        await self.send_offer(
            pc,
            test_client,
            webrtc_id=None,
            response_code=422,
        )

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"modality": "dummy"}], indirect=True
    )
    @pytest.mark.parametrize(
        "audio, video",
        [
            (True, False),
            (False, True),
            (True, True),
        ],
    )
    async def test_incorrect_modality(self, test_client_and_stream, audio, video):
        test_client, stream = test_client_and_stream
        pc, channel = await self.setup_peer_connection(audio=audio, video=video)
        with pytest.raises(ValueError, match=r"Modality must be .*"):
            await self.send_offer(pc, test_client, audio=audio, video=video)
        await self.close_peer_connection(pc)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"concurrency_limit": 1}], indirect=True
    )
    async def test_concurrency_limit_reached_two_peers(self, test_client_and_stream):
        test_client, stream = test_client_and_stream
        pc1, channel = await self.setup_peer_connection(video=True)
        pc2, channel = await self.setup_peer_connection(video=True)
        await self.send_offer(pc1, test_client)
        status, metadata = await self.send_offer(
            pc2, test_client, return_status_and_metadata=True
        )
        assert status == "failed"
        assert metadata == {"error": "connection_already_exists"}

        await self.close_peer_connection(pc1)
        await self.close_peer_connection(pc2)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"concurrency_limit": 2}], indirect=True
    )
    async def test_concurrency_limit_reached_three_peers_same_id(
        self, test_client_and_stream
    ):
        test_client, stream = test_client_and_stream
        pc1, channel = await self.setup_peer_connection(video=True)
        pc2, channel = await self.setup_peer_connection(video=True)
        pc3, channel = await self.setup_peer_connection(video=True)
        await self.send_offer(pc1, test_client)
        status, metadata = await self.send_offer(
            pc2, test_client, return_status_and_metadata=True
        )
        assert status == "failed"
        assert metadata == {"error": "connection_already_exists"}
        status, metadata = await self.send_offer(
            pc3, test_client, return_status_and_metadata=True
        )
        assert status == "failed"
        assert metadata == {"error": "connection_already_exists"}

        await self.close_peer_connection(pc1)
        await self.close_peer_connection(pc2)
        await self.close_peer_connection(pc3)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"concurrency_limit": 2}], indirect=True
    )
    async def test_concurrency_limit_reached_three_peers(self, test_client_and_stream):
        test_client, stream = test_client_and_stream
        pc1, channel = await self.setup_peer_connection(video=True)
        pc2, channel = await self.setup_peer_connection(video=True)
        pc3, channel = await self.setup_peer_connection(video=True)
        await self.send_offer(pc1, test_client, webrtc_id="foo")
        await self.send_offer(pc2, test_client, webrtc_id="bar")
        status, metadata = await self.send_offer(
            pc3, test_client, webrtc_id="baz", return_status_and_metadata=True
        )

        assert status == "failed"
        assert metadata == {"error": "concurrency_limit_reached", "limit": 2}

        await self.close_peer_connection(pc1)
        await self.close_peer_connection(pc2)
        await self.close_peer_connection(pc3)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream", [{"concurrency_limit": 1}], indirect=True
    )
    async def test_concurrency_limit_reached_peers_with_no_mediastreams(
        self, test_client_and_stream
    ):
        test_client, stream = test_client_and_stream
        pc1, channel = await self.setup_peer_connection()
        pc2, channel = await self.setup_peer_connection()
        await self.send_offer(pc1, test_client, webrtc_id="foo")
        status, metadata = await self.send_offer(
            pc2, test_client, webrtc_id="bar", return_status_and_metadata=True
        )
        assert status == "failed"
        assert metadata == {"error": "concurrency_limit_reached", "limit": 1}

        await self.close_peer_connection(pc1)
        await self.close_peer_connection(pc2)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "test_client_and_stream",
        [
            {"allow_extra_tracks": True, "modality": "audio"},
            {"allow_extra_tracks": True, "modality": "video"},
        ],
        indirect=True,
    )
    async def test_successful_connection_allow_extra_tracks(
        self, test_client_and_stream
    ):
        test_client, stream = test_client_and_stream
        pc, channel = await self.setup_peer_connection(audio=True, video=True)
        await self.send_offer(pc, test_client, audio=True, video=True)
        await self.close_peer_connection(pc)
